{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 1 排序"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 1.1 sort：返回逆序排序后的Tensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "a = tf.random.shuffle(tf.range(6))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=4, shape=(6,), dtype=int32, numpy=array([3, 1, 5, 2, 0, 4])>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=17, shape=(6,), dtype=int32, numpy=array([0, 1, 2, 3, 4, 5])>"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf.sort(a)  # 默认是顺序排列"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=30, shape=(6,), dtype=int32, numpy=array([0, 1, 2, 3, 4, 5])>"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf.sort(a, direction='ASCENDING')  # 默认顺序排列"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=40, shape=(6,), dtype=int32, numpy=array([5, 4, 3, 2, 1, 0])>"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf.sort(a, direction='DESCENDING')  # 指定逆序排列"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "也对多维Tensor排序，当对多维Tensor进行排序时，可以通过axis参数指定需要排序的维度，默认axis默认值为-1，也就是对最后一维进行排序。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "b = tf.random.uniform([3, 3], minval=1, maxval=10,dtype=tf.int32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=46, shape=(3, 3), dtype=int32, numpy=\n",
       "array([[1, 2, 1],\n",
       "       [9, 6, 3],\n",
       "       [6, 4, 1]])>"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=59, shape=(3, 3), dtype=int32, numpy=\n",
       "array([[1, 1, 2],\n",
       "       [3, 6, 9],\n",
       "       [1, 4, 6]])>"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf.sort(b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=91, shape=(3, 3), dtype=int32, numpy=\n",
       "array([[1, 2, 1],\n",
       "       [6, 4, 1],\n",
       "       [9, 6, 3]])>"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf.sort(b,axis=0)  # 通过axis参数指定第一维度，也就是列进行排序"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 1.2 argsort：返回排序后的索引"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=4, shape=(6,), dtype=int32, numpy=array([1, 3, 2, 4, 5, 0])>"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=125, shape=(6,), dtype=int32, numpy=array([5, 0, 2, 1, 3, 4])>"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf.argsort(a, direction='ASCENDING') # 返回排序之后的索引组成的Tensor, 默认是顺序排列"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=136, shape=(6,), dtype=int32, numpy=array([4, 3, 1, 2, 0, 5])>"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf.argsort(a, direction='DESCENDING') # n逆序排列"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "可以通过axis参数指定需要排序的维度，默认获取-1维度排序后索引："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=46, shape=(3, 3), dtype=int32, numpy=\n",
       "array([[1, 2, 1],\n",
       "       [9, 6, 3],\n",
       "       [6, 4, 1]])>"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=134, shape=(3, 3), dtype=int32, numpy=\n",
       "array([[0, 2, 1],\n",
       "       [2, 1, 0],\n",
       "       [2, 1, 0]])>"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf.argsort(b)  # 默认对最后一维度排序，也就是以行为单位排序"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=149, shape=(3, 3), dtype=int32, numpy=\n",
       "array([[0, 0, 0],\n",
       "       [2, 2, 2],\n",
       "       [1, 1, 1]])>"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf.argsort(b,axis=0)  # 指定第一维度进行排序，也就是以列为单位进行排序"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "返回的张量中，每一个元素表示b中原来元素在该行中的索引。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 1.3 top_k：返回逆序排序后的前$k$个元素组成的Tensor"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "sort()方法和argsort()方法都是对给定Tensor的所有元素进行排序，在某些情况下如果我们只是要获取排序的前几个元素，这时候使用sort()或argsort()方法就有些浪费时间了，这时候可以使用top_k()方法。top_k()方法可以指定获取前k个元素。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "注意：top_k()方法在tf.math模块中。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=4, shape=(6,), dtype=int32, numpy=array([3, 1, 5, 2, 0, 4])>"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "top_2 = tf.math.top_k(a, 2)  # 获取排序后前两位"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "TopKV2(values=<tf.Tensor: id=153, shape=(2,), dtype=int32, numpy=array([5, 4])>, indices=<tf.Tensor: id=154, shape=(2,), dtype=int32, numpy=array([2, 5])>)"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "top_2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "从上述输出可以看到，top_k()方法返回的是一个TopKV2类型对象，内部包含两部分数据：第一部分是排序后的真实数据[5, 4]，可以通过TopKV2对象的values属性获取；第二部分是排序后数据所在原Tensor中的索引[2, 5]，可以通过TopKV2对象的indices获取。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=153, shape=(2,), dtype=int32, numpy=array([5, 4])>"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "top_2.values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=154, shape=(2,), dtype=int32, numpy=array([2, 5])>"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "top_2.indices"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "对于高维Tensor也是一样的:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=152, shape=(3, 3), dtype=int32, numpy=\n",
       "array([[7, 9, 7],\n",
       "       [4, 3, 1],\n",
       "       [1, 1, 6]])>"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "TopKV2(values=<tf.Tensor: id=211, shape=(3, 2), dtype=int32, numpy=\n",
       "array([[9, 7],\n",
       "       [4, 3],\n",
       "       [6, 1]])>, indices=<tf.Tensor: id=212, shape=(3, 2), dtype=int32, numpy=\n",
       "array([[1, 0],\n",
       "       [0, 1],\n",
       "       [2, 0]])>)"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf.math.top_k(b, 2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "注意：top_k()方法只能对最后一维度进行排序。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 2 最小值、最大值、平均值"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 2.1 reduce_min、reduce_max、reduce_mean"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**（1）reduce_min()：求最小值**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "a = tf.random.uniform([3, 3], minval=1, maxval=10, dtype=tf.int32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=162, shape=(3, 3), dtype=int32, numpy=\n",
       "array([[4, 9, 5],\n",
       "       [8, 6, 1],\n",
       "       [8, 7, 1]])>"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "不指定维度时，获取整个Tensor的最小值："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=169, shape=(), dtype=int32, numpy=1>"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf.reduce_min(a)  # 最小值"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "通过axis参数可以对指定维度求最小值："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=172, shape=(3,), dtype=int32, numpy=array([4, 6, 1])>"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    " tf.reduce_min(a, axis=0)  # 求指定维度的最小值"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**（2）reduce_max():求最大值**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=175, shape=(), dtype=int32, numpy=9>"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf.reduce_max(a)  # 最大值"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=190, shape=(3,), dtype=int32, numpy=array([9, 8, 8])>"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    " tf.reduce_max(a, axis=-1)  # 求最后一维度的最大值"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**（3）reduce_mean():求平均值**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "不指定维度时，求整个Tensor所有元素的平均值："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=227, shape=(), dtype=int32, numpy=4>"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    " tf.reduce_mean(a)  # 整个Tensor所有元素的平均值"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=196, shape=(3,), dtype=int32, numpy=array([6, 7, 2])>"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf.reduce_mean(a, axis=0)  # 求第一维度（行）均值"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在上面求均值的例子中，因为Tensor的dtype为int32，所以求出来的均值也是int32，而不是浮点型。如果需要求浮点型的均值，就需要将a的类型先转换为float32："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=200, shape=(3,), dtype=float32, numpy=array([6.6666665, 7.3333335, 2.3333333], dtype=float32)>"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf.reduce_mean(tf.cast(a, tf.float32), axis=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 2.2 argmin()、argmax()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "argmin()、argmax()返回最大值最小值的索引组成的Tensor。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**（1）argmin():求最小值索引**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "a = tf.random.uniform([3,3],minval=1, maxval=10, dtype=tf.int32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=205, shape=(3, 3), dtype=int32, numpy=\n",
       "array([[5, 6, 1],\n",
       "       [3, 7, 2],\n",
       "       [7, 1, 6]])>"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "b = tf.random.uniform([3,3,3],minval=1, maxval=10, dtype=tf.int32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=210, shape=(3, 3, 3), dtype=int32, numpy=\n",
       "array([[[5, 4, 7],\n",
       "        [4, 3, 9],\n",
       "        [5, 3, 6]],\n",
       "\n",
       "       [[9, 5, 3],\n",
       "        [3, 2, 7],\n",
       "        [5, 6, 1]],\n",
       "\n",
       "       [[9, 9, 5],\n",
       "        [5, 4, 4],\n",
       "        [7, 1, 1]]])>"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=213, shape=(3,), dtype=int64, numpy=array([1, 2, 0], dtype=int64)>"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf.argmin(a)  # 默认是第0维度"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=216, shape=(3, 3), dtype=int64, numpy=\n",
       "array([[0, 0, 1],\n",
       "       [1, 1, 2],\n",
       "       [0, 2, 1]], dtype=int64)>"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf.argmin(b)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "对于shape为（3, 3）的Tensor，argmin(a)返回的是shape为（3,）的Tensor，因为没有指定比较的维度，默认比较的是第0维度的元素，也就是每一列数据；对于shape为（3，3,3）的Tensor，argmin(a)返回的是shape为（3,3）的Tensor，默认比较的是第0维度的元素，也就是每一块对应位置的元素，例如第一块的5、第二块的9、第三块的9比较，第一块的5最小，索引为0，所以返回的Tensor中第一个元素是0。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "注意：argmin()方法在没有指定维度时，默认返回的是第0维度最小值的索引，这与reducemin()方法不同，reducemin()方法在没有指定维度是是返回整个Tensor中所有元素中的最小值。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**（2）argmax():求最大值索引**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [],
   "source": [
    "a = tf.random.uniform([3,3,3],minval=1, maxval=10, dtype=tf.int32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=221, shape=(3, 3, 3), dtype=int32, numpy=\n",
       "array([[[1, 2, 7],\n",
       "        [9, 3, 3],\n",
       "        [5, 4, 8]],\n",
       "\n",
       "       [[8, 5, 1],\n",
       "        [2, 6, 5],\n",
       "        [2, 1, 2]],\n",
       "\n",
       "       [[8, 9, 7],\n",
       "        [3, 3, 9],\n",
       "        [7, 7, 2]]])>"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=233, shape=(3, 3), dtype=int64, numpy=\n",
       "array([[1, 2, 0],\n",
       "       [0, 1, 2],\n",
       "       [2, 2, 0]], dtype=int64)>"
      ]
     },
     "execution_count": 51,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf.argmax(a, axis=0)  # 第一维度，也就是每一块"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=236, shape=(3, 3), dtype=int64, numpy=\n",
       "array([[2, 0, 2],\n",
       "       [0, 1, 0],\n",
       "       [1, 2, 0]], dtype=int64)>"
      ]
     },
     "execution_count": 52,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf.argmax(a, axis=2)  # 第三维度，也就是每一行"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "study_python",
   "language": "python",
   "name": "study_python"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
